#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include "verify.h"
#include "scopelock.h"
#include "poller.h"

// Construct an empty select()-based poller: no fds watched (_highfd == -1)
// and a mutex that serializes access to the fd sets and _highfd.
SelectIO::SelectIO() : _highfd(-1)
{
	VERIFY(0 == pthread_mutex_init(&_m, NULL));
	// Nothing is watched for writing or reading yet.
	FD_ZERO(&_wset);
	FD_ZERO(&_rset);
}

// Tear down the poller; releases the mutex initialized in the constructor.
SelectIO::~SelectIO()
{
	VERIFY(0 == pthread_mutex_destroy(&_m));
}

// Start watching fd for the events selected by flag (read, write or both).
// Watching is additive: previously requested directions stay watched.
// Returns false if fd is outside [0, FD_SETSIZE) or flag is not recognized;
// in that case no state is modified.
bool SelectIO::WatchFd(int fd, PollFlag flag)
{
	ScopeLock m(&_m);
	// FD_SET on an fd outside [0, FD_SETSIZE) is undefined behavior.
	if (fd < 0 || fd >= FD_SETSIZE)
		return false;
	if (flag == POLLER_RD) {
		FD_SET(fd, &_rset);
	} else if (flag == POLLER_WR) {
		FD_SET(fd, &_wset);
	} else if (flag == POLLER_RDWR) {
		FD_SET(fd, &_rset);
		FD_SET(fd, &_wset);
	} else {
		return false;
	}
	// Only track the new high-water mark once the fd is actually in a set,
	// so a rejected flag cannot inflate the range scanned by select().
	if (fd > _highfd)
		_highfd = fd;
	return true;
}

// Stop watching fd for the events selected by flag.
// Returns true when fd is no longer watched for any event afterwards, false
// when it is still watched in the other direction, when fd is outside
// [0, FD_SETSIZE), or when flag is not recognized.
bool SelectIO::UnwatchFd(int fd, PollFlag flag)
{
	ScopeLock m(&_m);
	// FD_CLR/FD_ISSET on an fd outside [0, FD_SETSIZE) is undefined behavior.
	if (fd < 0 || fd >= FD_SETSIZE)
		return false;
	if (flag == POLLER_RD) {
		FD_CLR(fd, &_rset);
	} else if (flag == POLLER_WR) {
		FD_CLR(fd, &_wset);
	} else if (flag == POLLER_RDWR) {
		FD_CLR(fd, &_rset);
		FD_CLR(fd, &_wset);
	} else {
		return false;
	}

	bool unwatched = !FD_ISSET(fd, &_rset) && !FD_ISSET(fd, &_wset);
	if (unwatched && fd == _highfd) {
		// Walk down to the next fd still present in either set. If none
		// remain, _highfd returns to -1, the same "empty" value the
		// constructor uses.
		int newhigh = fd - 1;
		while (newhigh >= 0 && !FD_ISSET(newhigh, &_rset) &&
		       !FD_ISSET(newhigh, &_wset))
			--newhigh;
		_highfd = newhigh;
	}

	return unwatched;
}

bool SelectIO::IsWatched(int fd, PollFlag flag)
{
	if (fd > FD_SETSIZE)
		return false;
	if (flag == POLLER_RD) {
		return FD_ISSET(fd, &_rset);
	} else if (flag == POLLER_WR) {
		return FD_ISSET(fd, &_wset);
	} else if (flag == POLLER_RDWR) {
		return FD_ISSET(fd, &_rset) && FD_ISSET(fd, &_wset);
	} else {
		return false;
	}
}

// Wait up to timeout milliseconds (negative = wait indefinitely) for any
// watched fd to become ready, appending readable fds to *rv and writable fds
// to *wv. On EINTR or a select() failure nothing is appended (failures other
// than EINTR are reported via perror).
void SelectIO::WaitEvents(std::vector<int> *rv, std::vector<int> *wv, int timeout)
{
	fd_set trset, twset;
	int highfd;
	struct timeval tv, *ptv = NULL;

	// Snapshot the watch state so select() can block without holding _m.
	{
		ScopeLock m(&_m);
		trset = _rset;
		twset = _wset;
		highfd = _highfd;
	}
	if (timeout >= 0) {
		// tv_usec must stay below one second or select() fails with EINVAL,
		// so split the millisecond timeout into seconds and microseconds.
		tv.tv_sec = timeout / 1000;
		tv.tv_usec = (timeout % 1000) * 1000;
		ptv = &tv;
	}
	int nfd = select(highfd+1, &trset, &twset, NULL, ptv);
	if (nfd < 0) {
		if (errno == EINTR) {
			return ;
		}
		perror("select error");
		return ;
	}
	// Stop scanning once every ready descriptor reported by select() is found.
	for (int fd = 0, cnt = 0; fd <= highfd; ++fd) {
		if (FD_ISSET(fd, &trset)) {
			rv->push_back(fd);
			if (++cnt == nfd)
				break;
		}
		if (FD_ISSET(fd, &twset)) {
			wv->push_back(fd);
			if (++cnt == nfd)
				break;
		}
	}
}

// Build an epoll-based poller able to track descriptors 0.._maxfd-1:
// an epoll instance, an event buffer of _maxfd entries for epoll_wait, and a
// zero-filled per-fd table of the event bits currently registered.
EpollIO::EpollIO(int maxfd) : _maxfd(maxfd)
{
	VERIFY(0 <= (_efd = epoll_create(_maxfd)));

	_evts = static_cast<struct epoll_event *>(
		malloc(sizeof(struct epoll_event) * _maxfd));
	VERIFY(NULL != _evts);

	// calloc zero-fills: every fd starts out unwatched.
	_status = static_cast<int *>(calloc(_maxfd, sizeof(int)));
	VERIFY(NULL != _status);
}

// Release everything acquired in the constructor: the epoll descriptor, the
// epoll_wait event buffer and the per-fd status table.
// NOTE(review): this makes EpollIO own raw resources; if the class is ever
// copied, copy construction/assignment must be disabled in the header.
EpollIO::~EpollIO()
{
	close(_efd);
	free(_evts);
	free(_status);
}

bool EpollIO::WatchFd(int fd, PollFlag flag)
{
	if (fd > _maxfd)
		return false;
	int op = _status[fd] ? EPOLL_CTL_MOD : EPOLL_CTL_ADD;
	int status = 0;
	struct epoll_event ev;
	
	//use EPOLLET ?
	ev.events = 0;
	ev.data.fd = fd;

	if (flag == POLLER_RD) {
		status |= EPOLLIN;
	} else if (flag == POLLER_WR) {
		status |= EPOLLOUT;
	} else if (flag == POLLER_RDWR) {
		status |= (EPOLLIN | EPOLLOUT);
	} else {
		return false;
	}
	if (_status[fd] == status)
		return true;
	
	ev.events |= status;
	_status[fd] |= status;

	return epoll_ctl(_efd, op, fd, &ev) == 0;
}

// Stop watching fd for the events selected by flag.
// Returns true when fd is no longer watched for any event afterwards (it has
// been removed from the epoll set, or was never in it), false when it is still
// watched in the other direction, fd is outside [0, _maxfd), or flag is
// unrecognized.
bool EpollIO::UnwatchFd(int fd, PollFlag flag)
{
	// _status has exactly _maxfd entries.
	if (fd < 0 || fd >= _maxfd)
		return false;

	int status = 0;
	if (flag == POLLER_RD) {
		status = EPOLLIN;
	} else if (flag == POLLER_WR) {
		status = EPOLLOUT;
	} else if (flag == POLLER_RDWR) {
		status = EPOLLIN | EPOLLOUT;
	} else {
		return false;
	}

	int oldstatus = _status[fd] & POLLER_MASK;
	int newstatus = oldstatus & ~status;

	// None of the requested bits were registered: nothing to tell the kernel.
	// (Issuing EPOLL_CTL_DEL for an unregistered fd would fail and trip VERIFY.)
	if (newstatus == oldstatus)
		return newstatus == 0;

	struct epoll_event ev = {};
	ev.events = newstatus;
	ev.data.fd = fd;

	// Keep the fd registered while any direction remains; otherwise drop it.
	int op = newstatus ? EPOLL_CTL_MOD : EPOLL_CTL_DEL;
	VERIFY(epoll_ctl(_efd, op, fd, &ev) == 0);
	_status[fd] = newstatus;

	return op == EPOLL_CTL_DEL;
}

// Report whether fd is currently watched for flag. For POLLER_RDWR both
// directions must be watched, matching SelectIO::IsWatched. Returns false for
// fds outside [0, _maxfd) and for unrecognized flags.
bool EpollIO::IsWatched(int fd, PollFlag flag)
{
	// _status has exactly _maxfd entries.
	if (fd < 0 || fd >= _maxfd)
		return false;

	int wanted;
	if (flag == POLLER_RD) {
		wanted = EPOLLIN;
	} else if (flag == POLLER_WR) {
		wanted = EPOLLOUT;
	} else if (flag == POLLER_RDWR) {
		wanted = EPOLLIN | EPOLLOUT;
	} else {
		return false;
	}
	return (_status[fd] & wanted) == wanted;
}

void EpollIO::WaitEvents(std::vector<int> *rv, std::vector<int> *wv, int timeout)
{
	int nfd = epoll_wait(_efd, _evts, _maxfd, timeout);
	if (nfd < 0) {
		if (errno == EINTR) {
			return ;
		}
		perror("epoll_wait error");
		return ;
	}
	for (int i = 0; i < nfd; ++i) {
		if (_evts[i].events & EPOLLIN) {
			rv->push_back(_evts[i].data.fd);
		}
		if (_evts[i].events & EPOLLOUT) {
			wv->push_back(_evts[i].data.fd);
		}
	}
}

#ifdef _COLIN_TEST_
// Writer thread for the self-test: after two seconds, writes one byte to the
// pipe write end passed in data (an int carried through void* via intptr_t).
void *thread_fun(void *data)
{
	int fd = (int)(intptr_t)data;
	printf("entering thread!\n");
	sleep(2);
	if (write(fd, "a", 1) != 1)
		perror("write");
	printf("thread done!\n");
	return 0;
}

// Self-test: watch the read end of a pipe with SelectIO, let a helper thread
// write into it, and check that the fd is reported readable and can then be
// unwatched. Exits after the first wait that reports nothing (timeout).
int main(int argc, char **argv)
{
	(void)argc;
	(void)argv;

	int pipefd[2];
	VERIFY(pipe(pipefd) == 0);

	// Stack object: destroyed through its own type, no reliance on a
	// virtual destructor in Poller.
	SelectIO selectio;
	Poller *poller = &selectio;
	VERIFY(poller->WatchFd(pipefd[0], POLLER_RD));
	VERIFY(poller->IsWatched(pipefd[0], POLLER_RD));

	pthread_t tid;
	VERIFY(pthread_create(&tid, NULL, thread_fun, (void *)(intptr_t)pipefd[1]) == 0);
	for ( ; ; ) {
		std::vector<int> readable, writeable;
		poller->WaitEvents(&readable, &writeable, 5000);
		if (readable.empty() && writeable.empty()) {
			printf("timeout !\n");
			sleep(1);
			break;
		}
		for (size_t i = 0; i < readable.size(); ++i) {
			printf("fd %d readable!\n", readable[i]);
			VERIFY(poller->UnwatchFd(readable[i], POLLER_RD));
			BUG_ON(poller->IsWatched(readable[i], POLLER_RD));
		}
		for (size_t i = 0; i < writeable.size(); ++i) {
			printf("fd %d writeable!\n", writeable[i]);
			VERIFY(poller->UnwatchFd(writeable[i], POLLER_WR));
			BUG_ON(poller->IsWatched(writeable[i], POLLER_WR));
		}
	}

	pthread_join(tid, NULL);
	close(pipefd[0]);
	close(pipefd[1]);
	return 0;
}
#endif
